

import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torch
import sys

import argparse

from scipy.stats import norm



# Command-line options for selecting what to plot.
# NOTE(review): parse_args() is not called in this script and args.num is
# never read, so --num currently has no effect on the plots.
parser = argparse.ArgumentParser()
# nargs='+' yields a list when the flag is given, so the default is a list
# too; argparse does not pass defaults through nargs.
parser.add_argument('--num', nargs='+', default=[1], type=int, required=False,
                    help='List # you want to plot')



# Palette for histogram / normal-fit series; indices wrap modulo its length.
colorSel = ['tab:blue','tab:orange','tab:green','tab:red','tab:purple','tab:brown','tab:pink','tab:gray','tab:olive','tab:cyan','blue','orange','green','red','purple','brown','pink','gray','rosybrown','goldenrod','blueviolet','indigo','firebrick','khaki','teal']
nEQ = []         # every loaded model, in load order
paramsDict = {}  # parameter name -> list of flattened numpy arrays, one per train set


def _load_model(path):
    """Load a model that was saved whole (pickled module) and map it to the CPU.

    ``map_location='cpu'`` lets checkpoints written from a GPU run open on a
    CPU-only machine. The file holds a full module (its ``named_parameters``
    are used below), not a state_dict, so ``weights_only=False`` is needed on
    torch >= 2.6, where the default became True.
    Unpickling can execute arbitrary code: only load trusted, locally
    produced result files.
    """
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # torch < 1.13 does not accept the weights_only keyword.
        return torch.load(path, map_location='cpu')


#for trainConfig in ['simp0_mse', 'simp0_crossEntropy', 'simp1_mse', 'simp1_crossEntropy']:
trainConfigs = ['simp0_mse', 'simp1_mse']
trainSet = ['train1e7']#, 'train1e5', 'train1e6']#, 'train1e7']
# Number of series overlaid in each per-parameter figure: one per (config, train set).
nSeries = len(trainConfigs) * len(trainSet)

for idx3, trainConfig in enumerate(trainConfigs):
    # Rebuild per config so a parameter absent from this config's models is
    # not re-plotted with the previous config's data.
    paramsDict = {}
    for val in trainSet:
        model = _load_model('./results/pam4_%s_fc1_v2/nEQ_pam4_20dB_%s.pt' % (val, trainConfig))
        nEQ.append(model)
        for name, param in model.named_parameters():
            paramsDict.setdefault(name, []).append(param.detach().cpu().numpy().flatten())

    # One figure per parameter name; figures are reused across configs, so
    # every config's series is overlaid in the same window.
    for idx, key in enumerate(paramsDict):
        plt.figure(key, figsize=(6, 3))
        plt.grid(True)
        params = paramsDict[key]
        # Common bin edges over all train sets of this config; concatenate
        # also copes with arrays of differing length.
        bins = np.histogram(np.concatenate(params), bins=50)[1]

        for idx2, val in enumerate(trainSet):
            # Distinct colour per (parameter, config, train set); wrap rather
            # than raise IndexError when there are more series than colours.
            color = colorSel[(idx * nSeries + idx3 * len(trainSet) + idx2) % len(colorSel)]
            plt.hist(params[idx2], bins, density=True, color=color,
                     label=key + '_' + trainConfig + ' ' + val, alpha=0.5, histtype='bar')
            # Overlay the maximum-likelihood normal fit over the current x-range.
            mu, std = norm.fit(params[idx2])
            xmin, xmax = plt.xlim()
            x = np.linspace(xmin, xmax, 100)
            plt.plot(x, norm.pdf(x, mu, std), color=color)

        plt.legend(loc='best')



#plt.figure(0)
#plt.title('MOD:NRZ, Layer:10x32x32x2, train1e4')
#nEQ = torch.load('./results/nrz_train1e4_fc1/nEQ_nrz_10dB.pt')
#for name, param in nEQ.named_parameters():                
#   param = param.to('cpu')
#   param = param.detach()
#   param = param.flatten()
#   print(f"name: {name} params:\n{param}")                
#   plt.hist(param, label=name, bins=50, alpha=0.5)
#plt.xlim([-3, 3])
#plt.grid(True)
#plt.legend(loc='best')
#
#plt.figure(1)
#plt.title('MOD:NRZ, Layer:10x32x32x2, train1e5')
#nEQ = torch.load('./results/nrz_train1e5_fc1/nEQ_nrz_10dB.pt')
#for name, param in nEQ.named_parameters():                
#   param = param.to('cpu')
#   param = param.detach()
#   param = param.flatten()
#   print(f"name: {name} params:\n{param}")                
#   plt.hist(param, label=name, bins=50, alpha=0.5)
#plt.xlim([-3, 3])
#plt.grid(True)
#plt.legend(loc='best')
#
#plt.figure(2)
#plt.title('MOD:NRZ, Layer:10x32x32x2, train1e6')
#nEQ = torch.load('./results/nrz_train1e6_fc1/nEQ_nrz_10dB.pt')
#for name, param in nEQ.named_parameters():                
#   param = param.to('cpu')
#   param = param.detach()
#   param = param.flatten()
#   print(f"name: {name} params:\n{param}")                
#   plt.hist(param, label=name, bins=50, alpha=0.5)
#plt.xlim([-3, 3])
#plt.grid(True)
#plt.legend(loc='best')




# Display all figures created above (one per model parameter name).
plt.show()
